{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(171200:281473112917552,MainProcess):2021-03-16-15:03:28.981.597 [mindspore/_check_version.py:207] MindSpore version 1.1.1 and \"te\" wheel package version 1.0 does not match, reference to the match info on: https://www.mindspore.cn/install\n",
      "MindSpore version 1.1.1 and \"topi\" wheel package version 0.6.0 does not match, reference to the match info on: https://www.mindspore.cn/install\n",
      "[WARNING] ME(171200:281473112917552,MainProcess):2021-03-16-15:03:29.503.486 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "from src.config import cfg\n",
    "from src.dataset import create_dataset\n",
    "from src.seq2seq import Seq2Seq, InferCell\n",
    "from src.seq2seq import Seq2Seq, WithLossCell\n",
    "\n",
    "from mindspore import Tensor, nn, Model, context\n",
    "from mindspore.train.serialization import load_param_into_net, load_checkpoint\n",
    "from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint, TimeMonitor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure MindSpore for graph mode on an Ascend device.\n",
    "# The card index defaults to 5 (the value this notebook was recorded with) and can be\n",
    "# overridden through the DEVICE_ID environment variable on hosts where card 5 is busy or absent.\n",
    "device_id = int(os.getenv('DEVICE_ID', '5'))\n",
    "context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target='Ascend', device_id=device_id)"
   ]
  },
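  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train\n",
    "\n",
    "Build the training dataset, wrap the Seq2Seq network with its loss cell, and train it with Adam while saving checkpoints."
   ]
  },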
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset fed to model.train(); path and batch size come from the shared config.\n",
    "ds_train = create_dataset(\n",
    "    cfg.dataset_path,\n",
    "    cfg.batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seq2Seq wrapped in its loss cell: this is the network that gets optimised.\n",
    "network = WithLossCell(Seq2Seq(cfg), cfg)\n",
    "\n",
    "# Adam over the wrapped network's trainable parameters.\n",
    "optimizer = nn.Adam(\n",
    "    network.trainable_params(),\n",
    "    learning_rate=cfg.learning_rate,\n",
    "    beta1=0.9,\n",
    "    beta2=0.98,\n",
    ")\n",
    "\n",
    "# No loss_fn is passed to Model because the loss is already part of `network`.\n",
    "model = Model(network, optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 step: 125, loss is 2.5471632\n",
      "epoch time: 72593.791 ms, per step time: 580.750 ms\n",
      "epoch: 2 step: 125, loss is 2.5645504\n",
      "epoch time: 11230.366 ms, per step time: 89.843 ms\n",
      "epoch: 3 step: 125, loss is 2.3836899\n",
      "epoch time: 11235.888 ms, per step time: 89.887 ms\n",
      "epoch: 4 step: 125, loss is 2.279439\n",
      "epoch time: 11229.956 ms, per step time: 89.840 ms\n",
      "epoch: 5 step: 125, loss is 1.5323433\n",
      "epoch time: 11232.835 ms, per step time: 89.863 ms\n",
      "epoch: 6 step: 125, loss is 1.3322783\n",
      "epoch time: 11236.202 ms, per step time: 89.890 ms\n",
      "epoch: 7 step: 125, loss is 0.8172446\n",
      "epoch time: 11236.513 ms, per step time: 89.892 ms\n",
      "epoch: 8 step: 125, loss is 0.6874578\n",
      "epoch time: 11227.472 ms, per step time: 89.820 ms\n",
      "epoch: 9 step: 125, loss is 0.46486482\n",
      "epoch time: 11228.080 ms, per step time: 89.825 ms\n",
      "epoch: 10 step: 125, loss is 0.39268598\n",
      "epoch time: 11235.647 ms, per step time: 89.885 ms\n",
      "epoch: 11 step: 125, loss is 0.22333553\n",
      "epoch time: 11271.179 ms, per step time: 90.169 ms\n",
      "epoch: 12 step: 125, loss is 0.17082311\n",
      "epoch time: 11265.679 ms, per step time: 90.125 ms\n",
      "epoch: 13 step: 125, loss is 0.16510005\n",
      "epoch time: 11263.505 ms, per step time: 90.108 ms\n",
      "epoch: 14 step: 125, loss is 0.062068082\n",
      "epoch time: 11265.564 ms, per step time: 90.125 ms\n",
      "epoch: 15 step: 125, loss is 0.13220823\n",
      "epoch time: 11263.859 ms, per step time: 90.111 ms\n"
     ]
    }
   ],
   "source": [
    "loss_cb = LossMonitor()\n",
    "config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, keep_checkpoint_max=cfg.keep_checkpoint_max)\n",
    "ckpoint_cb = ModelCheckpoint(prefix=\"gru\", directory=cfg.ckpt_save_path, config=config_ck)\n",
    "time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())\n",
    "callbacks = [time_cb, ckpoint_cb, loss_cb]\n",
    "\n",
    "model.train(cfg.num_epochs, ds_train, callbacks=callbacks, dataset_sink_mode=True)\n"
   ]
  },
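  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Eval\n",
    "\n",
    "Rebuild the network for inference, load the checkpoint from `cfg.checkpoint_path`, and print the expected and predicted Chinese translation for the first sentence of each evaluation batch."
   ]
  },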
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): rank and device_num are not referenced by any other cell of this notebook.\n",
    "rank = 0\n",
    "device_num = 1\n",
    "\n",
    "# Evaluation dataset: same path as training, eval batch size, is_training switched off.\n",
    "ds_eval = create_dataset(\n",
    "    cfg.dataset_path,\n",
    "    cfg.eval_batch_size,\n",
    "    is_training=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference variant of the network; this rebinds the `network` name used for training above.\n",
    "network = InferCell(Seq2Seq(cfg, is_train=False), cfg)\n",
    "network.set_train(False)\n",
    "\n",
    "# Restore weights from the checkpoint named in the config.\n",
    "# NOTE(review): cfg.checkpoint_path is expected to point at a file written by the training run - confirm.\n",
    "parameter_dict = load_checkpoint(cfg.checkpoint_path)\n",
    "load_param_into_net(network, parameter_dict)\n",
    "\n",
    "model = Model(network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "English: who likes beans ? \n",
      "expect Chinese: 谁喜欢豆子？\n",
      "predict Chinese: 谁喜欢豆子？\n",
      " \n",
      "English: who built it ? \n",
      "expect Chinese: 这是谁建的？\n",
      "predict Chinese: 这是谁建的？\n",
      " \n",
      "English: tom is very quiet . \n",
      "expect Chinese: 汤姆很安静。\n",
      "predict Chinese: 汤姆人很好。\n",
      " \n",
      "English: are you finished ? \n",
      "expect Chinese: 你结束了吗？\n",
      "predict Chinese: 你结束了吗？\n",
      " \n",
      "English: i don t get it . \n",
      "expect Chinese: 我不懂。\n",
      "predict Chinese: 我无所谓。\n",
      " \n",
      "English: i understand . \n",
      "expect Chinese: 我明白了。\n",
      "predict Chinese: 我明白了。\n",
      " \n",
      "English: you made me laugh . \n",
      "expect Chinese: 我被你逗乐了。\n",
      "predict Chinese: 我被你逗乐了。\n",
      " \n",
      "English: excuse me . \n",
      "expect Chinese: 对不起。\n",
      "predict Chinese: 对不起。\n",
      " \n",
      "English: it s business . \n",
      "expect Chinese: 公事公办。\n",
      "predict Chinese: 公事公办。\n",
      " \n",
      "English: she is graceful . \n",
      "expect Chinese: 她举止优雅。\n",
      "predict Chinese: 她举止优雅。\n",
      " \n",
      "English: he s not home . \n",
      "expect Chinese: 他不在家。\n",
      "predict Chinese: 他不在家。\n",
      " \n",
      "English: it s very big . \n",
      "expect Chinese: 它很大。\n",
      "predict Chinese: 它很大。\n",
      " \n",
      "English: what s that ? \n",
      "expect Chinese: 那是什么？\n",
      "predict Chinese: 那是什么？\n",
      " \n",
      "English: tom hit a triple . \n",
      "expect Chinese: 汤姆击出三垒安打。\n",
      "predict Chinese: 汤姆击出三垒安打。\n",
      " \n",
      "English: i booked a seat . \n",
      "expect Chinese: 我订了一个位子。\n",
      "predict Chinese: 我订了一个位子。\n",
      " \n",
      "English: how about you ? \n",
      "expect Chinese: 您呢？\n",
      "predict Chinese: 你们呢？\n",
      " \n",
      "English: life goes on . \n",
      "expect Chinese: 人生会继续。\n",
      "predict Chinese: 人生会继续。\n",
      " \n",
      "English: do you like music ? \n",
      "expect Chinese: 你爱音乐吗？\n",
      "predict Chinese: 你喜欢音乐吗？\n",
      " \n",
      "English: i bathe every day . \n",
      "expect Chinese: 我每天都洗澡。\n",
      "predict Chinese: 我每天都洗澡。\n",
      " \n",
      "English: earth is a planet . \n",
      "expect Chinese: 地球是一个行星。\n",
      "predict Chinese: 地球是一个行星。\n",
      " \n"
     ]
    }
   ],
   "source": [
    "with open(os.path.join(cfg.dataset_path,\"en_vocab.txt\"), 'r', encoding='utf-8') as f:\n",
    "    data = f.read()\n",
    "en_vocab = list(data.split('\\n'))\n",
    "\n",
    "with open(os.path.join(cfg.dataset_path,\"ch_vocab.txt\"), 'r', encoding='utf-8') as f:\n",
    "    data = f.read()\n",
    "ch_vocab = list(data.split('\\n'))\n",
    "\n",
    "for data in ds_eval.create_dict_iterator():\n",
    "    en_data=''\n",
    "    ch_data=''\n",
    "    for x in data['encoder_data'][0].asnumpy():\n",
    "        if x == 0:\n",
    "            break\n",
    "        en_data += en_vocab[x]\n",
    "        en_data += ' '\n",
    "    for x in data['decoder_data'][0].asnumpy():\n",
    "        if x == 0:\n",
    "            break\n",
    "        if x == 1:\n",
    "            continue\n",
    "        ch_data += ch_vocab[x]\n",
    "    output = network(data['encoder_data'],data['decoder_data'])\n",
    "    print('English:', en_data)\n",
    "    print('expect Chinese:', ch_data)\n",
    "    out =''\n",
    "    for x in output[0].asnumpy():\n",
    "        if x == 0:\n",
    "            break\n",
    "        out += ch_vocab[x]\n",
    "    print('predict Chinese:', out)\n",
    "    print(' ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
